{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8b5fb82b",
   "metadata": {},
   "source": [
    "# SFT_GRPO_6: Putting it all together - Training Wordle"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2076ba3d",
   "metadata": {},
   "source": [
    "Import dependencies and set up the Predibase client for training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6edc92c2-fe1d-430a-a4c1-cbd50ca37272",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from predibase import (\n",
    "    Predibase,\n",
    "    GRPOConfig,\n",
    "    RewardFunctionsConfig,\n",
    "    RewardFunctionsRuntimeConfig,\n",
    "    SFTConfig,\n",
    "    SamplingParamsConfig,\n",
    ")\n",
    "from datasets import load_dataset\n",
    "from dotenv import load_dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e18ec496-1344-4429-b050-a78274443224",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "load_dotenv(\"../.env\")\n",
    "pb = Predibase(api_token=os.environ[\"PREDIBASE_API_KEY\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53a0e509",
   "metadata": {},
   "source": [
    "Load the GRPO [Wordle training dataset](https://huggingface.co/datasets/predibase/wordle-grpo) from Hugging Face:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274d09ac-2af2-46b7-b136-8739a54ce266",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "# Load dataset from Hugging Face\n",
    "dataset = load_dataset(\"predibase/wordle-grpo\", split=\"train\")\n",
    "dataset = dataset.to_pandas()\n",
    "\n",
    "# Upload dataset to Predibase\n",
    "try:\n",
    "    dataset = pb.datasets.from_pandas_dataframe(\n",
    "        dataset,\n",
    "        name=\"wordle_grpo_data\"\n",
    "    )\n",
    "except Exception:\n",
    "    # Upload failed, most likely because the dataset already exists; fetch it by name\n",
    "    dataset = pb.datasets.get(\"wordle_grpo_data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e369b2",
   "metadata": {},
   "source": [
    "Create a training repo and load the Wordle reward functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "181d7552-4e35-4768-b018-e4eb55375441",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "# Create repository in Predibase\n",
    "# Uncomment the line below if running in your own environment; the repo is already set up for you here\n",
    "# repo = pb.repos.create(name=\"wordle\", exists_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9fafc31",
   "metadata": {},
   "source": [
    "<div style=\"background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\">\n",
    "<b>Note:</b> You can access the full code of the reward functions, stored in <code>reward_functions.py</code>, by 1) clicking the <em>\"File\"</em> option in the top menu of the notebook and then 2) clicking <em>\"Open\"</em>.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81fa458-c29c-456d-b720-6fad654f90ea",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# Import reward functions\n",
    "from reward_functions import (\n",
    "    guess_value,\n",
    "    output_format_check,\n",
    "    uses_previous_feedback,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7ee03ed",
   "metadata": {},
   "source": [
    "## Set up the training run"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb4506e0",
   "metadata": {},
   "source": [
    "<div style=\"background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\">\n",
    "<b>Note:</b> The following cell will not run on the learning platform. If you decide to run it on your own computer, set the <code>PREDIBASE_API_KEY</code> environment variable, which the setup above reads, to your own API key.\n",
    "\n",
    "You can get free credits to try out Predibase at [this website](https://predibase.com/free-trial).\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6079648e-acb3-4c1e-a39a-057b46ae136e",
   "metadata": {
    "height": 387
   },
   "outputs": [],
   "source": [
    "# Create a GRPO training run in Predibase by specifying the config,\n",
    "# dataset, repository, and reward functions\n",
    "pb.finetuning.jobs.create(\n",
    "    config=GRPOConfig(\n",
    "        base_model=\"qwen2-5-7b-instruct\",\n",
    "        reward_fns=RewardFunctionsConfig(\n",
    "            runtime=RewardFunctionsRuntimeConfig(\n",
    "                packages=[\"pandas\"]\n",
    "            ),\n",
    "            functions={\n",
    "                \"output_format_check\": output_format_check,\n",
    "                \"uses_previous_feedback\": uses_previous_feedback,\n",
    "                \"guess_value\": guess_value,\n",
    "            }\n",
    "        ),\n",
    "        sampling_params=SamplingParamsConfig(max_tokens=4096),\n",
    "        num_generations=16\n",
    "    ),\n",
    "    dataset=dataset,\n",
    "    repo=\"wordle\",\n",
    "    description=\"Wordle GRPO\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53ce3479-ce7a-4803-bfec-5a07267af2be",
   "metadata": {},
   "source": [
    "## Try out SFT and SFT+GRPO on Predibase\n",
    "\n",
    "You can use the code below to set up an SFT training job in Predibase, and then use the resulting checkpoint as the starting point for a GRPO run.\n",
    "\n",
    "This example uses the [Wordle SFT dataset](https://huggingface.co/datasets/predibase/wordle-sft) available on Hugging Face.\n",
    "\n",
    "### SFT training on Predibase"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de2d2d4f",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "```python\n",
    "# Load the SFT dataset from Hugging Face\n",
    "dataset = load_dataset(\"predibase/wordle-sft\", split=\"train\")\n",
    "dataset = dataset.to_pandas()\n",
    "\n",
    "# Upload dataset to Predibase\n",
    "dataset = pb.datasets.from_pandas_dataframe(dataset, name=\"wordle_sft_data\")\n",
    "\n",
    "# Create repository in Predibase\n",
    "repo = pb.repos.create(name=\"wordle\", exists_ok=True)\n",
    "\n",
    "# Create an SFT training run in Predibase by specifying the config, dataset, and repository\n",
    "pb.finetuning.jobs.create(\n",
    "    config=SFTConfig(\n",
    "        base_model=\"qwen2-5-7b-instruct\",\n",
    "        epochs=10,\n",
    "        rank=64,\n",
    "        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"down_proj\", \"up_proj\"],\n",
    "    ),\n",
    "    dataset=dataset,\n",
    "    repo=\"wordle\",\n",
    "    description=\"Wordle SFT, 10 epochs\"\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41d317d2",
   "metadata": {},
   "source": [
    "### SFT + GRPO training on Predibase"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0997722d",
   "metadata": {},
   "source": [
    "```python\n",
    "# Use the same dataset as the GRPO training run\n",
    "dataset = pb.datasets.get(\"wordle_grpo_data\")\n",
    "\n",
    "# Create a GRPO training run that continues from the SFT checkpoint by specifying the config, dataset, repository, and reward functions\n",
    "pb.finetuning.jobs.create(\n",
    "    config=GRPOConfig(\n",
    "        base_model=\"qwen2-5-7b-instruct\",\n",
    "        reward_fns=RewardFunctionsConfig(\n",
    "            runtime=RewardFunctionsRuntimeConfig(packages=[\"pandas\"]),\n",
    "            functions={\n",
    "                \"output_format_check\": output_format_check,\n",
    "                \"uses_previous_feedback\": uses_previous_feedback,\n",
    "                \"guess_value\": guess_value,\n",
    "            }\n",
    "        ),\n",
    "        epochs=3,\n",
    "        enable_early_stopping=False,\n",
    "        sampling_params=SamplingParamsConfig(max_tokens=4096),\n",
    "        num_generations=8\n",
    "    ),\n",
    "    continue_from_version=\"wordle/1\", # change \"1\" to the version number of the SFT training run in the repository\n",
    "    dataset=dataset,\n",
    "    repo=\"wordle\",\n",
    "    description=\"Wordle GRPO\"\n",
    ")\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
